31. 客户购买意愿分析(K近邻算法)

本节学习目标

  • 理解K近邻(KNN)算法的核心原理
  • 掌握欧氏距离的计算方法
  • 学会使用 scikit-learn 构建KNN分类模型
  • 理解网格搜索与交叉验证调优流程
  • 能够利用混淆矩阵评估分类模型性能

K近邻算法(KNN)是什么?

KNN 是一种经典的监督学习分类算法

  • 核心思想:‘物以类聚,人以群分’
  • 决策方式:根据K个最近邻居的类别进行多数投票
  • 不需要显式的模型训练过程,属于懒惰学习(Lazy Learning)

典型商业应用

  • 客户分类与画像
  • 信用评分与风险评估
  • 精准营销中的购买意愿预测

KNN算法原理:找到最近的K个邻居

决策规则

  1. 给定一个待预测样本 \(x\)
  2. 在训练集中找到与 \(x\) 距离最近\(K\) 个样本
  3. \(K\) 个邻居中出现次数最多的类别,即为 \(x\) 的预测类别

关键问题:如何衡量’距离’?

距离度量:欧氏距离

最常用的距离度量是欧氏距离(Euclidean Distance):

\[d(x,y) = \sqrt{\sum_{i=1}^{n}(x_i - y_i)^2}\]

  • \(x_i, y_i\):两个样本在第 \(i\) 个特征上的取值
  • \(n\):特征总数

直观理解:就是多维空间中两点之间的直线距离

K值的选择至关重要

K值过小 K值过大
模型过拟合 模型欠拟合
对噪声敏感 决策边界过于平滑
分类不稳定 丢失局部特征

最佳实践:通过交叉验证自动搜索最优K值

案例背景:社交网络广告客户购买预测

业务场景

  • 某社交网络平台投放广告后,收集用户数据
  • 目标:根据用户特征预测是否会购买产品

数据特征

特征 说明
Age 用户年龄
EstimatedSalary 预估年薪
Purchased 是否购买(0/1)

⭐ 平台任务代码

Listing 1
# 注:04_Social_Network_Ads.csv数据文件本地没有,但平台已经内置
# ⚠️ 平台原始代码 - 请原样输入至教学平台(注释除外),平台才会判定答案正确
# 导入相关模块
import pandas as pd
import numpy as np  # 导入NumPy数值计算库
from sklearn.model_selection import train_test_split  # 导入Scikit-learn的train_test_split模块
from sklearn.preprocessing import StandardScaler  # 导入Scikit-learn的StandardScaler模块
from sklearn.neighbors import KNeighborsClassifier  # 导入Scikit-learn的KNeighborsClassifier模块
from sklearn.model_selection import GridSearchCV  # 导入Scikit-learn的GridSearchCV模块
from sklearn.metrics import confusion_matrix  # 导入Scikit-learn的confusion_matrix模块
 
# 导入数据集
data = pd.read_csv("04_Social_Network_Ads.csv")
 
# 数据集划分
x = data[["Age","EstimatedSalary"]]
y = data["Purchased"]  # 提取Purchased列作为y变量
# 划分训练集和测试集
x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=0.30, random_state=0)
 
# 数据标准化
transfer1 = StandardScaler()
x_train = transfer1.fit_transform(x_train)  # 对数据进行变换
x_test = transfer1.transform(x_test)  # 对数据进行变换
 
# 训练模型
estimator = KNeighborsClassifier(algorithm='kd_tree')
 
# 模型选择与调优——网格搜索和交叉验证
# 准备要调的超参数
param_dict = {"n_neighbors": [1, 3, 5, 7, 9, 11, 13]}
estimator = GridSearchCV(estimator, param_grid=param_dict, cv=4)  # 创建网格搜索交叉验证对象,自动寻找最优超参数
estimator.fit(x_train,y_train)  # 在数据上训练estimator模型
 
# 模型评估
y_pre = estimator.predict(x_test)
print("预测结果:\n", y_pre)  # 输出预测结果:\n
print("准确率为:\n", estimator.score(x_test, y_test))  # 输出准确率为:\n
# print("对比真实值和预测值:\n", y_test==y_pre)
print("在交叉验证中最好的结果为:\n", estimator.best_score_)
print("使用网格搜索的最好的模型:\n", estimator.best_estimator_)  # 输出使用网格搜索的最好的模型:\n

代码解读:步骤一 — 数据导入与划分

数据导入:使用 pd.read_csv() 读取CSV文件

特征与标签分离

  • x:Age 和 EstimatedSalary(特征矩阵)
  • y:Purchased(目标变量)

数据集划分

  • 训练集占 70%,测试集占 30%
  • random_state=0 保证结果可复现

代码解读:步骤二 — 数据标准化

为什么需要标准化?

  • 年龄范围:约 18~60
  • 薪资范围:约 15,000~150,000
  • 尺度差异巨大,距离计算会被薪资主导

StandardScaler 的作用

  • 将每个特征转换为均值为0、标准差为1的分布
  • fit_transform:在训练集上拟合并转换
  • transform:用训练集的参数转换测试集(避免数据泄露)

代码解读:步骤三 — 网格搜索与交叉验证

GridSearchCV 自动完成超参数调优:

  • 参数网格:K = [1, 3, 5, 7, 9, 11, 13]
  • 4折交叉验证:训练集分为4份,轮流验证
  • 自动找到使准确率最高的K值

代码解读:步骤四 — 模型评估

评估指标

  • 准确率(Accuracy):预测正确的比例
  • 混淆矩阵(Confusion Matrix):详细展示分类结果
Listing 3
# 注:该代码块依赖的数据来自上方平台任务代码块,因其未执行,本块也无法执行

# ==================== 预测测试集 ====================
# 使用最优模型对测试集进行预测
y_pred = grid_search.predict(X_test_scaled)  # 输出预测类别(0或1)

# ==================== 计算准确率 ====================
accuracy = accuracy_score(y_test, y_pred)  # 计算预测准确率
print(f'测试集准确率: {accuracy:.4f}')  # 打印准确率,范围[0,1]

# ==================== 生成混淆矩阵 ====================
# 混淆矩阵展示了预测结果与真实情况的对比
cm = confusion_matrix(y_test, y_pred)  # 生成2x2混淆矩阵
print(f'\n混淆矩阵:')
print(cm)
# 混淆矩阵格式:
#           预测0  预测1
# 实际0    [TN    [FP
# 实际1    [FN    [TP

# ==================== 可视化混淆矩阵 ====================
import matplotlib.pyplot as plt  # 导入绘图库
import seaborn as sns  # 导入统计绘图库

plt.figure(figsize=(8, 6))  # 创建8x6英寸的画布
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues')  # 绘制热力图,annot显示数值,fmt整数格式
plt.title('混淆矩阵', fontsize=14)  # 设置标题
plt.xlabel('预测类别', fontsize=12)  # x轴标签
plt.ylabel('真实类别', fontsize=12)  # y轴标签
plt.tight_layout()  # 自动调整布局,避免标签重叠
plt.show()  # 显示图形

# ==================== 输出最优模型 ====================
print(f'\n最佳模型: {grid_search.best_estimator_}')  # 打印最优模型的完整配置

如何理解混淆矩阵?

预测:不购买(0) 预测:购买(1)
实际:不购买(0) TN(正确拒绝) FP(误报)
实际:购买(1) FN(漏报) TP(正确识别)
  • TN:正确预测为不购买
  • TP:正确预测为购买
  • FP:实际不买,误判为买(浪费营销资源)
  • FN:实际会买,漏判为不买(错失客户)

本节总结

步骤 操作 关键要点
数据准备 特征提取与集合划分 70/30 分割比例
数据预处理 StandardScaler标准化 消除特征尺度差异
模型训练 KNN + GridSearchCV 自动搜索最优K值
模型评估 准确率 + 混淆矩阵 多角度评估性能

核心收获:KNN简单直观,但特征标准化和K值选择是成败关键。